{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from functools import reduce\n",
    "from pyraul.tools.dumping import print_torch_tensor, gen_cpp_dtVec\n",
    "from pyraul.tools.seed import set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def broadcast_test(size, broadcast_size, start=0):\n",
    "    \"\"\"Print an arange tensor, a tensor of ones and their broadcast product.\n",
    "\n",
    "    Builds `x` from consecutive integers beginning at `start`, reshaped to\n",
    "    `size`, multiplies it by `torch.ones(broadcast_size)` and prints both\n",
    "    operands and the product `z` (label, shape, values). Afterwards every\n",
    "    element of the flattened `x` and of the flattened `z` is printed with\n",
    "    its flat index.\n",
    "\n",
    "    Args:\n",
    "        size: shape of `x`.\n",
    "        broadcast_size: shape of the all-ones tensor `y`.\n",
    "        start: first value stored in `x`.\n",
    "    \"\"\"\n",
    "    numel = reduce(lambda a, b: a * b, size, 1)\n",
    "    x = torch.arange(start, start + numel).reshape(size)\n",
    "    y = torch.ones(broadcast_size)\n",
    "    z = x * y\n",
    "    print(\"x\", x.shape, x, sep=\"\\n\")\n",
    "    print(\"y\", y.shape, y, sep=\"\\n\")\n",
    "    print(\"z\", z.shape, z, sep=\"\\n\")\n",
    "\n",
    "    # Dump elements in flattened order; dedicated loop names keep `x`\n",
    "    # bound to the tensor instead of rebinding it to a single element.\n",
    "    for i, x_val in enumerate(x.flatten()):\n",
    "        print(f\"{i}: x={x_val.item()}\")\n",
    "    for i, z_val in enumerate(z.flatten()):\n",
    "        print(f\"{i}: z={z_val.item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x\n",
      "torch.Size([1, 2])\n",
      "tensor([[0, 1]])\n",
      "y\n",
      "torch.Size([2, 2])\n",
      "tensor([[1., 1.],\n",
      "        [1., 1.]])\n",
      "z\n",
      "torch.Size([2, 2])\n",
      "tensor([[0., 1.],\n",
      "        [0., 1.]])\n",
      "0: x=0\n",
      "1: x=1\n",
      "0: z=0.0\n",
      "1: z=1.0\n",
      "2: z=0.0\n",
      "3: z=1.0\n"
     ]
    }
   ],
   "source": [
    "# x of shape (1, 2) times ones of shape (2, 2): x is repeated along dim 0.\n",
    "broadcast_test(size=(1, 2), broadcast_size=(2, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x\n",
      "torch.Size([2, 1])\n",
      "tensor([[0],\n",
      "        [1]])\n",
      "y\n",
      "torch.Size([2, 2])\n",
      "tensor([[1., 1.],\n",
      "        [1., 1.]])\n",
      "z\n",
      "torch.Size([2, 2])\n",
      "tensor([[0., 0.],\n",
      "        [1., 1.]])\n",
      "0: x=0\n",
      "1: x=1\n",
      "0: z=0.0\n",
      "1: z=0.0\n",
      "2: z=1.0\n",
      "3: z=1.0\n"
     ]
    }
   ],
   "source": [
    "# x of shape (2, 1) times ones of shape (2, 2): x is repeated along dim 1.\n",
    "broadcast_test(size=(2, 1), broadcast_size=(2, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x\n",
      "torch.Size([2, 2])\n",
      "tensor([[0, 1],\n",
      "        [2, 3]])\n",
      "y\n",
      "torch.Size([2, 2])\n",
      "tensor([[1., 1.],\n",
      "        [1., 1.]])\n",
      "z\n",
      "torch.Size([2, 2])\n",
      "tensor([[0., 1.],\n",
      "        [2., 3.]])\n",
      "0: x=0\n",
      "1: x=1\n",
      "2: x=2\n",
      "3: x=3\n",
      "0: z=0.0\n",
      "1: z=1.0\n",
      "2: z=2.0\n",
      "3: z=3.0\n"
     ]
    }
   ],
   "source": [
    "# Both operands have shape (2, 2): baseline case without any broadcasting.\n",
    "broadcast_test(size=(2, 2), broadcast_size=(2, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x\n",
      "torch.Size([2, 1, 2])\n",
      "tensor([[[0, 1]],\n",
      "\n",
      "        [[2, 3]]])\n",
      "y\n",
      "torch.Size([2, 2, 2])\n",
      "tensor([[[1., 1.],\n",
      "         [1., 1.]],\n",
      "\n",
      "        [[1., 1.],\n",
      "         [1., 1.]]])\n",
      "z\n",
      "torch.Size([2, 2, 2])\n",
      "tensor([[[0., 1.],\n",
      "         [0., 1.]],\n",
      "\n",
      "        [[2., 3.],\n",
      "         [2., 3.]]])\n",
      "0: x=0\n",
      "1: x=1\n",
      "2: x=2\n",
      "3: x=3\n",
      "0: z=0.0\n",
      "1: z=1.0\n",
      "2: z=0.0\n",
      "3: z=1.0\n",
      "4: z=2.0\n",
      "5: z=3.0\n",
      "6: z=2.0\n",
      "7: z=3.0\n"
     ]
    }
   ],
   "source": [
    "# x of shape (2, 1, 2) times ones of shape (2, 2, 2): x is repeated along dim 1.\n",
    "broadcast_test(size=(2, 1, 2), broadcast_size=(2, 2, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x\n",
      "torch.Size([2, 2, 1])\n",
      "tensor([[[1],\n",
      "         [2]],\n",
      "\n",
      "        [[3],\n",
      "         [4]]])\n",
      "y\n",
      "torch.Size([2, 2, 2])\n",
      "tensor([[[1., 1.],\n",
      "         [1., 1.]],\n",
      "\n",
      "        [[1., 1.],\n",
      "         [1., 1.]]])\n",
      "z\n",
      "torch.Size([2, 2, 2])\n",
      "tensor([[[1., 1.],\n",
      "         [2., 2.]],\n",
      "\n",
      "        [[3., 3.],\n",
      "         [4., 4.]]])\n"
     ]
    }
   ],
   "source": [
    "# x of shape (2, 2, 1) times ones of shape (2, 2, 2): x is repeated along dim 2.\n",
    "broadcast_test(size=(2, 2, 1), broadcast_size=(2, 2, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x\n",
      "torch.Size([1, 2, 2])\n",
      "tensor([[[0, 1],\n",
      "         [2, 3]]])\n",
      "y\n",
      "torch.Size([2, 2, 2])\n",
      "tensor([[[1., 1.],\n",
      "         [1., 1.]],\n",
      "\n",
      "        [[1., 1.],\n",
      "         [1., 1.]]])\n",
      "z\n",
      "torch.Size([2, 2, 2])\n",
      "tensor([[[0., 1.],\n",
      "         [2., 3.]],\n",
      "\n",
      "        [[0., 1.],\n",
      "         [2., 3.]]])\n",
      "0: x=0\n",
      "1: x=1\n",
      "2: x=2\n",
      "3: x=3\n",
      "0: z=0.0\n",
      "1: z=1.0\n",
      "2: z=2.0\n",
      "3: z=3.0\n",
      "4: z=0.0\n",
      "5: z=1.0\n",
      "6: z=2.0\n",
      "7: z=3.0\n"
     ]
    }
   ],
   "source": [
    "# x of shape (1, 2, 2) times ones of shape (2, 2, 2): x is repeated along dim 0.\n",
    "broadcast_test(size=(1, 2, 2), broadcast_size=(2, 2, 2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x\n",
      "torch.Size([1, 1, 1, 1])\n",
      "tensor([[[[1]]]])\n",
      "y\n",
      "torch.Size([1, 2, 2, 1])\n",
      "tensor([[[[1.],\n",
      "          [1.]],\n",
      "\n",
      "         [[1.],\n",
      "          [1.]]]])\n",
      "z\n",
      "torch.Size([1, 2, 2, 1])\n",
      "tensor([[[[1.],\n",
      "          [1.]],\n",
      "\n",
      "         [[1.],\n",
      "          [1.]]]])\n",
      "0: x=1\n",
      "0: z=1.0\n",
      "1: z=1.0\n",
      "2: z=1.0\n",
      "3: z=1.0\n"
     ]
    }
   ],
   "source": [
    "# Single-element x (value 1) times ones of shape (1, 2, 2, 1): repeated along dims 1 and 2.\n",
    "broadcast_test(size=(1, 1, 1, 1), broadcast_size=(1, 2, 2, 1), start=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mul: element-wise multiplication with broadcasting"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([5, 3, 2, 3])\n",
      "const raul::dtVec x{0.49625658988952637_dt, 0.7682217955589294_dt, 0.08847743272781372_dt, 0.13203048706054688_dt, 0.30742281675338745_dt, 0.6340786814689636_dt, 0.4900934100151062_dt, 0.8964447379112244_dt, 0.455627977848053_dt, 0.6323062777519226_dt, 0.3488934636116028_dt, 0.40171730518341064_dt, 0.022325754165649414_dt, 0.16885894536972046_dt, 0.2938884496688843_dt, 0.518521785736084_dt, 0.6976675987243652_dt, 0.800011396408081_dt, 0.16102945804595947_dt, 0.28226858377456665_dt, 0.6816085577011108_dt, 0.9151939749717712_dt, 0.39709991216659546_dt, 0.8741558790206909_dt, 0.41940832138061523_dt, 0.5529070496559143_dt, 0.9527381062507629_dt, 0.036164820194244385_dt, 0.1852310299873352_dt, 0.37341737747192383_dt};\n",
      "const raul::dtVec y{0.3051000237464905_dt, 0.9320003986358643_dt, 0.17591017484664917_dt, 0.2698335647583008_dt, 0.15067976713180542_dt, 0.03171950578689575_dt, 0.20812976360321045_dt, 0.9297990202903748_dt, 0.7231091856956482_dt, 0.7423362731933594_dt, 0.5262957811355591_dt, 0.24365824460983276_dt, 0.584592342376709_dt, 0.033152639865875244_dt, 0.13871687650680542_dt, 0.242235004901886_dt, 0.815468966960907_dt, 0.793160617351532_dt, 0.2782524824142456_dt, 0.48195880651474_dt, 0.8197803497314453_dt, 0.9970665574073792_dt, 0.6984410881996155_dt, 0.5675464272499084_dt, 0.8352431654930115_dt, 0.2055988311767578_dt, 0.593172013759613_dt, 0.11234724521636963_dt, 0.1534569263458252_dt, 0.24170821905136108_dt, 0.7262365221977234_dt, 0.7010802030563354_dt, 0.2038237452507019_dt, 0.6510535478591919_dt, 0.7744860053062439_dt, 0.4368913173675537_dt, 0.5190907716751099_dt, 0.6158523559570312_dt, 0.8101882934570312_dt, 0.9800970554351807_dt, 0.1146882176399231_dt, 0.3167651295661926_dt, 0.6965049505233765_dt, 0.9142746925354004_dt, 0.9351036548614502_dt};\n",
      "const raul::dtVec z{0.1514078974723816_dt, 0.7159830331802368_dt, 0.015564080327749252_dt, 0.04028250649571419_dt, 0.2865181863307953_dt, 0.11154089123010635_dt, 0.13390667736530304_dt, 0.11575548350811005_dt, 0.0028064604848623276_dt, 0.035626258701086044_dt, 0.04632239788770676_dt, 0.0201126616448164_dt, 0.1032857671380043_dt, 0.7142918705940247_dt, 0.06397884339094162_dt, 0.027479473501443863_dt, 0.2858414351940155_dt, 0.45850813388824463_dt, 0.363814115524292_dt, 0.47179508209228516_dt, 0.11101751029491425_dt, 0.4693838953971863_dt, 0.18362115323543549_dt, 0.0978817343711853_dt, 0.286504864692688_dt, 0.029719509184360504_dt, 0.06320329010486603_dt, 0.36964139342308044_dt, 0.011566739529371262_dt, 0.05572497099637985_dt, 0.11871778219938278_dt, 0.731022834777832_dt, 0.36138617992401123_dt, 0.15316671133041382_dt, 0.2845118045806885_dt, 0.318626344203949_dt, 0.006212196312844753_dt, 0.08138305693864822_dt, 0.24092397093772888_dt, 0.14427997171878815_dt, 0.336247056722641_dt, 0.6558336019515991_dt, 0.022260263562202454_dt, 0.11793802678585052_dt, 0.16679534316062927_dt, 0.5170007348060608_dt, 0.4872797131538391_dt, 0.45404359698295593_dt, 0.018647434189915657_dt, 0.03471720218658447_dt, 0.17432640492916107_dt, 0.43309178948402405_dt, 0.1434396356344223_dt, 0.47454437613487244_dt, 0.018091216683387756_dt, 0.0433160699903965_dt, 0.16475039720535278_dt, 0.10281952470541_dt, 0.060937732458114624_dt, 0.2112906575202942_dt, 0.11694547533988953_dt, 0.1978929191827774_dt, 0.13892801105976105_dt, 0.6646472811698914_dt, 0.27839890122413635_dt, 0.1781737208366394_dt, 0.10483880341053009_dt, 0.21861307322978973_dt, 0.2977888584136963_dt, 0.5958402752876282_dt, 0.30754831433296204_dt, 0.38191109895706177_dt, 0.21771098673343658_dt, 0.3405091166496277_dt, 0.7718972563743591_dt, 0.018772823736071587_dt, 0.11407496780157089_dt, 0.3025383949279785_dt, 0.4110608696937561_dt, 0.06341192126274109_dt, 0.3017942011356354_dt, 0.03544503450393677_dt, 0.02124381624162197_dt, 
0.11828560382127762_dt, 0.29211997985839844_dt, 0.5055088996887207_dt, 0.8909088969230652_dt, 0.025188976898789406_dt, 0.16935203969478607_dt, 0.34918394684791565_dt};\n"
     ]
    }
   ],
   "source": [
    "# Forward reference data for broadcast multiplication: (5, 1, 2, 3) * (5, 3, 1, 3).\n",
    "set_seed(0)\n",
    "x = torch.rand(5, 1, 2, 3)\n",
    "y = torch.rand(5, 3, 1, 3)\n",
    "z = x * y\n",
    "print(z.shape)\n",
    "# Emit each tensor as a flattened dtVec literal; detach() is used in place\n",
    "# of the legacy .data attribute.\n",
    "for name, tensor in ((\"x\", x), (\"y\", y), (\"z\", z)):\n",
    "    print(gen_cpp_dtVec(tensor.detach().flatten(), name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "x (torch.Size([2, 1, 1, 1])):\n",
      "[0.49625658988952637, 0.7682217955589294]\n",
      "grad of x (torch.Size([2, 1, 1, 1])):\n",
      "[0.527930736541748, 0.527930736541748]\n",
      "y (torch.Size([1, 1, 1, 3])):\n",
      "[0.08847743272781372, 0.13203048706054688, 0.30742281675338745]\n",
      "grad of y (torch.Size([1, 1, 1, 3])):\n",
      "[1.2644784450531006, 1.2644784450531006, 1.2644784450531006]\n",
      "z (torch.Size([2, 1, 1, 3])):\n",
      "[0.04390750825405121, 0.0655210018157959, 0.15256059169769287, 0.06797029078006744, 0.10142869502305984, 0.23616890609264374]\n",
      "==============\n",
      "const raul::dtVec x{0.49625658988952637_dt, 0.7682217955589294_dt};\n",
      "const raul::dtVec y{0.08847743272781372_dt, 0.13203048706054688_dt, 0.30742281675338745_dt};\n",
      "const raul::dtVec z{0.04390750825405121_dt, 0.0655210018157959_dt, 0.15256059169769287_dt, 0.06797029078006744_dt, 0.10142869502305984_dt, 0.23616890609264374_dt};\n",
      "const raul::dtVec x_grad{0.527930736541748_dt, 0.527930736541748_dt};\n",
      "const raul::dtVec y_grad{1.2644784450531006_dt, 1.2644784450531006_dt, 1.2644784450531006_dt};\n"
     ]
    }
   ],
   "source": [
    "# Forward and backward reference data for broadcast multiplication:\n",
    "# x is (2, 1, 1, 1), y is (1, 1, 1, 3), gradients come from z.sum().\n",
    "set_seed(0)\n",
    "x = torch.rand(2, 1, 1, 1, requires_grad=True)\n",
    "y = torch.rand(1, 1, 1, 3, requires_grad=True)\n",
    "# z already requires grad because both operands do, so no explicit\n",
    "# requires_grad_ call is needed before backward().\n",
    "z = x * y\n",
    "z.sum().backward()\n",
    "print_torch_tensor(\"x\", x, grad=True)\n",
    "print_torch_tensor(\"y\", y, grad=True)\n",
    "print_torch_tensor(\"z\", z)\n",
    "print(\"==============\")\n",
    "# detach() is used in place of the legacy .data attribute to obtain\n",
    "# gradient-free views of the values.\n",
    "print(gen_cpp_dtVec(x.detach().flatten(), \"x\"))\n",
    "print(gen_cpp_dtVec(y.detach().flatten(), \"y\"))\n",
    "print(gen_cpp_dtVec(z.detach().flatten(), \"z\"))\n",
    "print(gen_cpp_dtVec(x.grad.flatten(), \"x_grad\"))\n",
    "print(gen_cpp_dtVec(y.grad.flatten(), \"y_grad\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
